/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements the StreamExecutor interface by passing through to its
// implementation_ value (in pointer-to-implementation style), which
// implements StreamExecutorInterface.

#include "tensorflow/stream_executor/stream_executor_pimpl.h"

#include <atomic>
#include <memory>
#include <utility>

#include "absl/base/const_init.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/notification.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/fft.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/error.h"
#include "tensorflow/stream_executor/lib/stacktrace.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/lib/threadpool.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/rng.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"

namespace {
// When true, device allocations are tracked in mem_allocs_ /
// mem_alloc_bytes_ (see CreateAllocRecord / EraseAllocRecord). Any allocation
// still recorded is logged when the executor is destroyed. Nothing in the
// visible part of this file assigns to it, so it acts as a fixed default here.
bool FLAGS_check_device_leaks{false};
}  // namespace

namespace stream_executor {
namespace {

// Returns " <current stack trace>\n" when VLOG level 10 is enabled, otherwise
// an empty string. Appended to VLOG(1) messages about memory operations.
string StackTraceIfVLOG10() {
  if (!VLOG_IS_ON(10)) {
    return "";
  }
  return absl::StrCat(" ", port::CurrentStackTrace(), "\n");
}

// Schedules a marker task on `executor` and blocks until that task has run.
// Callers use this to let previously enqueued background work drain; the work
// submitted there is expected to be short. NOTE(review): with more than one
// worker thread this only proves the marker ran, not that every earlier task
// has finished. Confirm whether that matters for callers.
void BlockOnThreadExecutor(port::ThreadPool *executor) {
  absl::Notification marker_ran;
  executor->Schedule([&marker_ran] { marker_ran.Notify(); });
  marker_ran.WaitForNotification();
}

// Process-wide counter from which ScopedTracer draws correlation ids.
std::atomic_int_fast64_t correlation_id_generator(0);

}  // namespace

// RAII helper that reports a traced call to every registered TraceListener.
// The constructor invokes `begin_call` with a fresh correlation id and the
// begin arguments. The destructor invokes `complete_call` with the same id and
// `result`, which by then should hold the call's return value.
//
// Whether tracing is active is sampled once, at construction. A completion
// event is therefore emitted exactly when a begin event was emitted. Toggling
// StreamExecutor::tracing_enabled_ mid-call can no longer produce an unpaired
// event or read an uninitialized correlation id.
template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
          typename... BeginArgsT>
class ScopedTracer {
 public:
  ScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call,
               CompleteCallT complete_call, const ReturnT *result,
               BeginArgsT... begin_args)
      : stream_exec_(stream_exec),
        complete_call_(complete_call),
        result_(result),
        traced_(stream_exec->tracing_enabled_) {
    if (traced_) {
      // fetch_add returns the previous value, so ids start at -1.
      correlation_id_ =
          correlation_id_generator.fetch_add(1, std::memory_order_relaxed) - 1;
      Trace(begin_call, begin_args...);
    }
  }

  ~ScopedTracer() {
    if (traced_) {
      Trace(complete_call_, result_);
    }
  }

 private:
  // Invokes `callback` on every registered listener while holding mu_ in
  // shared mode. `args` are passed as lvalues, not forwarded. Forwarding
  // inside the loop could move from them for the first listener and hand
  // moved-from values to the rest.
  template <typename CallbackT, typename... TraceArgsT>
  void Trace(CallbackT callback, TraceArgsT... args) {
    {
      // Scoped so the lock is released as soon as all listeners are notified.
      absl::ReaderMutexLock lock{&stream_exec_->mu_};
      for (TraceListener *listener : stream_exec_->listeners_) {
        (listener->*callback)(correlation_id_, args...);
      }
    }
  }

  StreamExecutor *stream_exec_;
  CompleteCallT complete_call_;
  const ReturnT *result_;
  // True iff the begin event was emitted; gates the completion event.
  const bool traced_;
  int64 correlation_id_ = -1;
};

// Builds a ScopedTracer, deducing its template arguments from the call site.
// `result` is only observed (never owned). It is handed to the completion
// callback when the returned tracer is destroyed, if tracing is active.
template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
          typename... BeginArgsT>
ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>
MakeScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call,
                 CompleteCallT complete_call, ReturnT *result,
                 BeginArgsT... begin_args) {
  using TracerT =
      ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>;
  return TracerT(stream_exec, begin_call, complete_call, result,
                 std::forward<BeginArgsT>(begin_args)...);
}

// Declares a local named `tracer` inside a StreamExecutor member function.
// It reports TraceListener's LOC##Begin callback at the point of declaration
// and LOC##Complete when the enclosing scope exits. Both are subject to
// StreamExecutor::tracing_enabled_.
// The first variadic argument must be a pointer to the call's result, which
// the completion callback receives. The remaining arguments go to the begin
// callback. Declare the result variable BEFORE using this macro, so it is
// still alive when `tracer` is destroyed.
// Because the variable name is fixed, use at most once per scope.
#define SCOPED_TRACE(LOC, ...) \
  auto tracer =                \
      MakeScopedTracer(this, &LOC##Begin, &LOC##Complete, ##__VA_ARGS__);

// Constant-initialized (absl::kConstInit), so it is usable during static
// initialization without order-of-initialization hazards.
/* static */ absl::Mutex StreamExecutor::static_mu_(absl::kConstInit);

// Returns the per-device memory limit in bytes. The limit is taken from the
// TF_PER_DEVICE_MEMORY_LIMIT_MB environment variable, interpreted in MiB.
// Yields 0 when the variable is unset (0 is the default passed to the reader).
// A non-OK status from the reader is escalated by SE_CHECK_OK.
static int64 GetMemoryLimitBytes() {
  constexpr int64 kBytesPerMiB = 1ll << 20;
  int64 limit_mib = 0;
  SE_CHECK_OK(tensorflow::ReadInt64FromEnvVar("TF_PER_DEVICE_MEMORY_LIMIT_MB",
                                              0, &limit_mib));
  return limit_mib * kBytesPerMiB;
}

// Takes ownership of `implementation` and creates the background thread pool.
// Reads the optional per-device memory limit and classifies the platform
// kind from the platform's name.
StreamExecutor::StreamExecutor(
    const Platform *platform,
    std::unique_ptr<internal::StreamExecutorInterface> implementation,
    int device_ordinal)
    : platform_(platform),
      implementation_(std::move(implementation)),
      device_ordinal_(device_ordinal),
      background_threads_(new port::ThreadPool(
          port::Env::Default(), "stream_executor", kNumBackgroundThreads)),
      live_stream_count_(0),
      tracing_enabled_(false),
      mem_alloc_bytes_(0),
      memory_limit_bytes_(GetMemoryLimitBytes()),
      allocator_(this) {
  // Platform names are matched case-insensitively; anything unrecognized is
  // classified as kInvalid.
  const string lowered = absl::AsciiStrToLower(platform_->Name());
  platform_kind_ = [&lowered]() {
    if (lowered == "cuda") return PlatformKind::kCuda;
    if (lowered == "rocm") return PlatformKind::kROCm;
    if (lowered == "opencl") return PlatformKind::kOpenCL;
    if (lowered == "host") return PlatformKind::kHost;
    return PlatformKind::kInvalid;
  }();
}

// Waits for the background pool to run a marker task, then warns about
// streams that were never deallocated. When leak checking is enabled, logs
// every allocation still recorded in mem_allocs_.
StreamExecutor::~StreamExecutor() {
  BlockOnThreadExecutor(background_threads_.get());

  if (live_stream_count_.load() != 0) {
    LOG(WARNING) << "Not all streams were deallocated at executor destruction "
                 << "time. This may lead to unexpected/bad behavior - "
                 << "especially if any stream is still active!";
  }

  if (FLAGS_check_device_leaks) {
    // Iterate by const reference: each record carries a std::string stack
    // trace that a by-value loop variable would copy on every iteration.
    for (const auto &it : mem_allocs_) {
      LOG(INFO) << "Memory alloced at executor exit: addr: "
                << absl::StrFormat("%p", it.first)
                << ", bytes: " << it.second.bytes << ", trace: \n"
                << it.second.stack_trace;
    }
  }
}

// Initializes the implementation for this executor's device ordinal using the
// supplied options.
port::Status StreamExecutor::Init(DeviceOptions device_options) {
  return implementation_->Init(device_ordinal_,
                               std::move(device_options));
}

// Convenience overload: initializes with DeviceOptions::Default().
port::Status StreamExecutor::Init() {
  return Init(DeviceOptions::Default());
}

// Kernel and module management: each method below is a direct pass-through
// to the same-named method on implementation_.

port::Status StreamExecutor::GetKernel(const MultiKernelLoaderSpec &spec,
                                       KernelBase *kernel) {
  return implementation_->GetKernel(spec, kernel);
}

void StreamExecutor::UnloadKernel(const KernelBase *kernel) {
  implementation_->UnloadKernel(kernel);
}

port::Status StreamExecutor::LoadModule(const MultiModuleLoaderSpec &spec,
                                        ModuleHandle *module_handle) {
  return implementation_->LoadModule(spec, module_handle);
}

bool StreamExecutor::UnloadModule(ModuleHandle module_handle) {
  return implementation_->UnloadModule(module_handle);
}

// Frees `mem` through the implementation and drops its leak-tracking record
// (if it has a non-null address). Afterwards `mem` is reset to (nullptr, 0).
void StreamExecutor::Deallocate(DeviceMemoryBase *mem) {
  void *const device_ptr = mem->opaque();
  VLOG(1) << "Called StreamExecutor::Deallocate(mem=" << device_ptr
          << ") mem->size()=" << mem->size() << StackTraceIfVLOG10();

  if (device_ptr != nullptr) {
    EraseAllocRecord(device_ptr);
  }
  implementation_->Deallocate(mem);
  mem->Reset(nullptr, 0);
}

// Copies a snapshot of the allocation records into `*records_out`. The copy
// is made under a shared lock on mu_. Records are only populated while
// FLAGS_check_device_leaks is set.
void StreamExecutor::GetMemAllocs(std::map<void *, AllocRecord> *records_out) {
  absl::ReaderMutexLock reader_lock(&mu_);
  *records_out = mem_allocs_;
}

// Peer-access queries operate on the two executors' implementations.
bool StreamExecutor::CanEnablePeerAccessTo(StreamExecutor *other) {
  internal::StreamExecutorInterface *peer = other->implementation_.get();
  return implementation_->CanEnablePeerAccessTo(peer);
}

port::Status StreamExecutor::EnablePeerAccessTo(StreamExecutor *other) {
  internal::StreamExecutorInterface *peer = other->implementation_.get();
  return implementation_->EnablePeerAccessTo(peer);
}

// Pass-through to the implementation's current shared-memory configuration.
SharedMemoryConfig StreamExecutor::GetDeviceSharedMemoryConfig() {
  return implementation_->GetDeviceSharedMemoryConfig();
}

// Applies `config` via the implementation. Only kDefault, kFourByte and
// kEightByte are accepted. Any other value is logged and rejected with
// INVALID_ARGUMENT.
port::Status StreamExecutor::SetDeviceSharedMemoryConfig(
    SharedMemoryConfig config) {
  switch (config) {
    case SharedMemoryConfig::kDefault:
    case SharedMemoryConfig::kFourByte:
    case SharedMemoryConfig::kEightByte:
      return implementation_->SetDeviceSharedMemoryConfig(config);
    default:
      break;
  }
  string error_msg = absl::StrFormat(
      "Invalid shared memory config specified: %d", static_cast<int>(config));
  LOG(ERROR) << error_msg;
  return port::Status(port::error::INVALID_ARGUMENT, error_msg);
}

// Returns the device description. It is created on the first call (under
// mu_) and the cached instance is returned thereafter.
const DeviceDescription &StreamExecutor::GetDeviceDescription() const {
  absl::MutexLock lock(&mu_);
  if (device_description_ == nullptr) {
    device_description_ = CreateDeviceDescription();
  }
  return *device_description_;
}

// Device/capability queries: pure pass-throughs to implementation_.

int64 StreamExecutor::GetDeviceLoad() const {
  return implementation_->GetDeviceLoad();
}

int StreamExecutor::PlatformDeviceCount() const {
  return implementation_->PlatformDeviceCount();
}

bool StreamExecutor::SupportsBlasPlans() const {
  return implementation_->SupportsBlasPlans();
}

bool StreamExecutor::SupportsBlas() const {
  return implementation_->SupportsBlas();
}

bool StreamExecutor::SupportsRng() const {
  return implementation_->SupportsRng();
}

bool StreamExecutor::SupportsDnn() const {
  return implementation_->SupportsDnn();
}

bool StreamExecutor::GetConvolveAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveAlgorithms(with_winograd_nonfused, cc_major,
                                            cc_minor, out_algorithms);
}

// Delegates to the DNN backend's GetConvolveExecutionPlans. Returns false
// when no DNN support is available.
bool StreamExecutor::GetConvolveExecutionPlans(
    dnn::ConvolutionKind kind, dnn::DataType element_type, Stream *stream,
    const dnn::BatchDescriptor &input_descriptor,
    const dnn::FilterDescriptor &filter_descriptor,
    const dnn::BatchDescriptor &output_descriptor,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    std::vector<cudnn_frontend::ExecutionPlan> *out_exec_plans) {
  if (dnn::DnnSupport *support = AsDnn()) {
    return support->GetConvolveExecutionPlans(
        kind, element_type, stream, input_descriptor, filter_descriptor,
        output_descriptor, convolution_descriptor, out_exec_plans);
  }
  return false;
}

// Delegates to the DNN backend's GetFusedConvolveExecutionPlans (convolution
// plus bias). Returns false when no DNN support is available.
bool StreamExecutor::GetFusedConvolveExecutionPlans(
    dnn::ConvolutionKind kind, dnn::DataType element_type, Stream *stream,
    const dnn::BatchDescriptor &input_descriptor,
    const dnn::FilterDescriptor &filter_descriptor,
    const dnn::BatchDescriptor &bias_descriptor,
    const dnn::BatchDescriptor &output_descriptor,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    std::vector<cudnn_frontend::ExecutionPlan> *out_exec_plans) {
  if (dnn::DnnSupport *support = AsDnn()) {
    return support->GetFusedConvolveExecutionPlans(
        kind, element_type, stream, input_descriptor, filter_descriptor,
        bias_descriptor, output_descriptor, convolution_descriptor,
        out_exec_plans);
  }
  return false;
}

// Fills `out_algorithms` with RNN algorithms from the DNN backend. Returns
// false when no DNN support is available.
bool StreamExecutor::GetRnnAlgorithms(
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *support = AsDnn();
  return support != nullptr && support->GetRnnAlgorithms(out_algorithms);
}

bool StreamExecutor::GetConvolveBackwardDataAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveBackwardDataAlgorithms(
      with_winograd_nonfused, cc_major, cc_minor, out_algorithms);
}

bool StreamExecutor::GetConvolveBackwardFilterAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveBackwardFilterAlgorithms(
      with_winograd_nonfused, cc_major, cc_minor, out_algorithms);
}

// Fills `out_algorithms` with GEMM algorithms from the BLAS backend. Returns
// false when no BLAS support is available.
bool StreamExecutor::GetBlasGemmAlgorithms(
    std::vector<blas::AlgorithmType> *out_algorithms) {
  blas::BlasSupport *support = AsBlas();
  return support != nullptr && support->GetBlasGemmAlgorithms(out_algorithms);
}

// Creates a BlasLt matmul plan via the BLAS backend. Returns UNKNOWN if no
// BLAS support is available.
port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
StreamExecutor::CreateBlasLtMatmulPlan(
    const blas::BlasLtMatmulPlanParams& params) {
  if (blas::BlasSupport* support = AsBlas()) {
    return support->CreateBlasLtMatmulPlan(params);
  }
  return port::Status(port::error::UNKNOWN,
                      "Fail to find the blas implementation.");
}

// Lists BlasLt matmul algorithms for `plan` via the BLAS backend. Returns
// UNKNOWN if no BLAS support is available.
port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
StreamExecutor::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan* plan,
                                          size_t max_workspace_size,
                                          int max_algorithm_count) {
  if (blas::BlasSupport* support = AsBlas()) {
    return support->GetBlasLtMatmulAlgorithms(plan, max_workspace_size,
                                              max_algorithm_count);
  }
  return port::Status(port::error::UNKNOWN,
                      "Fail to find the blas implementation.");
}

// Creates an RNN descriptor via the DNN backend. Returns UNKNOWN if no DNN
// support is available.
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
StreamExecutor::createRnnDescriptor(
    int num_layers, int hidden_size, int input_size, int cell_size,
    int batch_size, dnn::RnnInputMode input_mode,
    dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
    dnn::DataType data_type, const dnn::AlgorithmConfig &algorithm_config,
    float dropout, uint64 seed, ScratchAllocator *state_allocator,
    bool use_padded_io) {
  if (dnn::DnnSupport *support = AsDnn()) {
    return support->createRnnDescriptor(
        num_layers, hidden_size, input_size, cell_size, batch_size, input_mode,
        direction_mode, rnn_mode, data_type, algorithm_config, dropout, seed,
        state_allocator, use_padded_io);
  }
  return port::Status(port::error::UNKNOWN,
                      "Fail to find the dnn implementation.");
}

// Creates a CTC-loss descriptor via the DNN backend. Returns UNKNOWN if no
// DNN support is available.
port::StatusOr<std::unique_ptr<dnn::CtcLossDescriptor>>
StreamExecutor::createCtcLossDescriptor(dnn::DataType data_type) {
  if (dnn::DnnSupport *support = AsDnn()) {
    return support->createCtcLossDescriptor(data_type);
  }
  return port::Status(port::error::UNKNOWN,
                      "Fail to find the dnn implementation.");
}

// Creates a fixed-length RNN sequence tensor descriptor via the DNN backend.
// Returns UNKNOWN if no DNN support is available.
port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
StreamExecutor::createRnnSequenceTensorDescriptor(int max_seq_length,
                                                  int batch_size, int data_size,
                                                  dnn::DataType data_type) {
  if (dnn::DnnSupport *support = AsDnn()) {
    return support->createRnnSequenceTensorDescriptor(
        max_seq_length, batch_size, data_size, data_type);
  }
  return port::Status(port::error::UNKNOWN,
                      "Fail to find the dnn implementation.");
}

// Variant taking per-batch sequence lengths and a time-major flag. Delegates
// to the DNN backend and returns UNKNOWN if no DNN support is available.
port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
StreamExecutor::createRnnSequenceTensorDescriptor(
    int max_seq_length, int batch_size, int data_size,
    const absl::Span<const int> &seq_lengths, bool time_major,
    dnn::DataType data_type) {
  if (dnn::DnnSupport *support = AsDnn()) {
    return support->createRnnSequenceTensorDescriptor(
        max_seq_length, batch_size, data_size, seq_lengths, time_major,
        data_type);
  }
  return port::Status(port::error::UNKNOWN,
                      "Fail to find the dnn implementation.");
}

// Creates an RNN state tensor descriptor via the DNN backend. Returns UNKNOWN
// if no DNN support is available.
port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
StreamExecutor::createRnnStateTensorDescriptor(int num_layer, int batch_size,
                                               int data_size,
                                               dnn::DataType data_type) {
  if (dnn::DnnSupport *support = AsDnn()) {
    return support->createRnnStateTensorDescriptor(num_layer, batch_size,
                                                   data_size, data_type);
  }
  return port::Status(port::error::UNKNOWN,
                      "Fail to find the dnn implementation.");
}

// Lazily created support objects. Each accessor takes mu_. If nothing is
// cached yet, it asks the implementation to create the object and returns
// whatever is cached, which may be nullptr. A nullptr result causes a new
// creation attempt on the next call.

dnn::DnnSupport *StreamExecutor::AsDnn() {
  absl::MutexLock lock(&mu_);
  if (dnn_ == nullptr) {
    dnn_.reset(implementation_->CreateDnn());
  }
  return dnn_.get();
}

blas::BlasSupport *StreamExecutor::AsBlas() {
  absl::MutexLock lock(&mu_);
  if (blas_ == nullptr) {
    blas_.reset(implementation_->CreateBlas());
  }
  return blas_.get();
}

fft::FftSupport *StreamExecutor::AsFft() {
  absl::MutexLock lock(&mu_);
  if (fft_ == nullptr) {
    fft_.reset(implementation_->CreateFft());
  }
  return fft_.get();
}

rng::RngSupport *StreamExecutor::AsRng() {
  absl::MutexLock lock(&mu_);
  if (rng_ == nullptr) {
    rng_.reset(implementation_->CreateRng());
  }
  return rng_.get();
}

// Emits a LaunchSubmit trace event, then launches `kernel` on `stream` through
// the implementation.
port::Status StreamExecutor::Launch(Stream *stream,
                                    const ThreadDim &thread_dims,
                                    const BlockDim &block_dims,
                                    const KernelBase &kernel,
                                    const KernelArgsArrayBase &args) {
  SubmitTrace(&TraceListener::LaunchSubmit, stream, thread_dims, block_dims,
              kernel, args);
  return implementation_->Launch(stream, thread_dims, block_dims, kernel,
                                 args);
}

// Graph capture / execution: opaque graph handles are passed straight through
// to the same-named methods on implementation_.

port::Status StreamExecutor::LaunchExecutableGraph(Stream *stream,
                                                   void *exec_graph) {
  return implementation_->LaunchExecutableGraph(stream, exec_graph);
}

port::Status StreamExecutor::BeginGraphCapture(Stream *stream) {
  return implementation_->BeginGraphCapture(stream);
}

port::StatusOr<void *> StreamExecutor::EndGraphCapture(Stream *stream,
                                                       void *graph) {
  return implementation_->EndGraphCapture(stream, graph);
}

port::StatusOr<void *> StreamExecutor::InstantiateGraph(void *graph,
                                                        void *graph_exec) {
  return implementation_->InstantiateGraph(graph, graph_exec);
}

port::Status StreamExecutor::UpdateExecutableGraph(void *graph,
                                                   void *graph_exec) {
  return implementation_->UpdateExecutableGraph(graph, graph_exec);
}

void StreamExecutor::DestroyExecutableGraph(void *context, void *exec_graph) {
  implementation_->DestroyExecutableGraph(context, exec_graph);
}

// Releases a (non-instantiated) graph handle. Every other method in this file
// forwards to the same-named implementation method, and this one now does
// too. It previously called DestroyExecutableGraph, which would treat a plain
// graph as an executable-graph handle.
// NOTE(review): assumes StreamExecutorInterface declares
// DestroyGraph(void *, void *) mirroring this method — confirm.
void StreamExecutor::DestroyGraph(void *context, void *graph) {
  implementation_->DestroyGraph(context, graph);
}

// Blocks the calling host thread until the implementation reports `stream` is
// done, and returns the implementation's status. The call is bracketed by
// BlockHostUntilDone begin/complete trace events.
port::Status StreamExecutor::BlockHostUntilDone(Stream *stream) {
  // Declared before the tracer so it is still alive when the tracer's
  // destructor reports it to listeners.
  port::Status result;
  SCOPED_TRACE(TraceListener::BlockHostUntilDone, &result, stream);

  result = implementation_->BlockHostUntilDone(stream);
  return result;
}

// Returns the implementation's status for `stream`.
port::Status StreamExecutor::GetStatus(Stream *stream) {
  return implementation_->GetStatus(stream);
}

// Allocates `size` bytes in `memory_space` through the implementation and
// records the allocation. If a positive per-device limit is configured and
// the request would exceed it, logs a warning and returns an empty
// DeviceMemoryBase.
// NOTE(review): mem_alloc_bytes_ is only updated by CreateAllocRecord /
// EraseAllocRecord, which do nothing unless FLAGS_check_device_leaks is set.
// The limit is therefore effectively checked against 0 otherwise.
DeviceMemoryBase StreamExecutor::Allocate(uint64 size, int64 memory_space) {
  const bool limit_configured = memory_limit_bytes_ > 0;
  if (limit_configured && mem_alloc_bytes_ + size > memory_limit_bytes_) {
    LOG(WARNING) << "Not enough memory to allocate " << size << " on device "
                 << device_ordinal_
                 << " within provided limit. [used=" << mem_alloc_bytes_
                 << ", limit=" << memory_limit_bytes_ << "]";
    return DeviceMemoryBase();
  }

  DeviceMemoryBase allocation = implementation_->Allocate(size, memory_space);
  VLOG(1) << "Called StreamExecutor::Allocate(size=" << size
          << ", memory_space=" << memory_space << ") returns "
          << allocation.opaque() << StackTraceIfVLOG10();
  CreateAllocRecord(allocation.opaque(), size);
  return allocation;
}

// Looks up `symbol_name` (optionally within `module_handle`) and returns its
// device address and size. On failure returns NOT_FOUND. The message differs
// depending on whether a module handle was supplied.
port::StatusOr<DeviceMemoryBase> StreamExecutor::GetUntypedSymbol(
    const string &symbol_name, ModuleHandle module_handle) {
  // Start from nullptr/0 so a failed lookup leaves well-defined values,
  // matching DeviceMemory semantics.
  void *opaque = nullptr;
  size_t bytes = 0;
  if (GetSymbol(symbol_name, module_handle, &opaque, &bytes)) {
    return DeviceMemoryBase(opaque, bytes);
  }

  const string message =
      static_cast<bool>(module_handle)
          ? absl::StrCat("Check if module containing symbol ", symbol_name,
                         " is loaded (module_handle = ",
                         reinterpret_cast<uintptr_t>(module_handle.id()), ")")
          : absl::StrCat("Check if kernel using the symbol is loaded: ",
                         symbol_name);
  return port::Status(port::error::NOT_FOUND, message);
}

// Pass-through symbol lookup. On success, `*mem` and `*bytes` are set by the
// implementation.
bool StreamExecutor::GetSymbol(const string &symbol_name,
                               ModuleHandle module_handle, void **mem,
                               size_t *bytes) {
  return implementation_->GetSymbol(symbol_name, module_handle, mem,
                                    bytes);
}

// Unified and host memory allocation. These forward to the implementation
// and log the call (with a stack trace at VLOG level 10).

void *StreamExecutor::UnifiedMemoryAllocate(uint64 bytes) {
  void *const allocation = implementation_->UnifiedMemoryAllocate(bytes);
  VLOG(1) << "Called StreamExecutor::UnifiedMemoryAllocate(size=" << bytes
          << ") returns " << allocation << StackTraceIfVLOG10();
  return allocation;
}

void StreamExecutor::UnifiedMemoryDeallocate(void *location) {
  VLOG(1) << "Called StreamExecutor::UnifiedMemoryDeallocate(location="
          << location << ")" << StackTraceIfVLOG10();
  implementation_->UnifiedMemoryDeallocate(location);
}

void *StreamExecutor::HostMemoryAllocate(uint64 size) {
  void *const allocation = implementation_->HostMemoryAllocate(size);
  VLOG(1) << "Called StreamExecutor::HostMemoryAllocate(size=" << size
          << ") returns " << allocation << StackTraceIfVLOG10();
  return allocation;
}

void StreamExecutor::HostMemoryDeallocate(void *location) {
  VLOG(1) << "Called StreamExecutor::HostMemoryDeallocate(location=" << location
          << ")" << StackTraceIfVLOG10();
  implementation_->HostMemoryDeallocate(location);
}

// Registers a host buffer with the implementation. Null or zero-sized
// requests are logged as warnings but still forwarded.
bool StreamExecutor::HostMemoryRegister(void *location, uint64 size) {
  VLOG(1) << "Called StreamExecutor::HostMemoryRegister(location=" << location
          << ", size=" << size << ")" << StackTraceIfVLOG10();
  const bool degenerate_request = location == nullptr || size == 0;
  if (degenerate_request) {
    LOG(WARNING) << "attempting to register null or zero-sized memory: "
                 << location << "; size " << size;
  }
  return implementation_->HostMemoryRegister(location, size);
}

// Unregisters a host buffer previously registered via HostMemoryRegister.
bool StreamExecutor::HostMemoryUnregister(void *location) {
  VLOG(1) << "Called StreamExecutor::HostMemoryUnregister(location=" << location
          << ")" << StackTraceIfVLOG10();
  return implementation_->HostMemoryUnregister(location);
}

// Synchronizes all device activity via the implementation, then waits on the
// background thread pool. The pool wait happens even if device
// synchronization failed. Returns the implementation's result.
bool StreamExecutor::SynchronizeAllActivity() {
  VLOG(1) << "Called StreamExecutor::SynchronizeAllActivity()"
          << StackTraceIfVLOG10();
  const bool device_synced = implementation_->SynchronizeAllActivity();

  // Background work is expected to be short and infallible, so draining it
  // unconditionally is safe.
  BlockOnThreadExecutor(background_threads_.get());
  return device_synced;
}

// Synchronously zeroes `size` bytes at `location` via the implementation.
// The VLOG prints the device address (location->opaque()), consistent with
// the other memory-operation logs in this file. Previously it printed the
// address of the host-side DeviceMemoryBase object, which cannot be matched
// against Allocate/Deallocate logs.
port::Status StreamExecutor::SynchronousMemZero(DeviceMemoryBase *location,
                                                uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemZero(location="
          << location->opaque() << ", size=" << size << ")"
          << StackTraceIfVLOG10();

  return implementation_->SynchronousMemZero(location, size);
}

// Synchronously sets `size` bytes at `location` to `value` via the
// implementation. Logs the device address, like SynchronousMemZero.
port::Status StreamExecutor::SynchronousMemSet(DeviceMemoryBase *location,
                                               int value, uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemSet(location="
          << location->opaque() << ", value=" << value << ", size=" << size
          << ")" << StackTraceIfVLOG10();

  return implementation_->SynchronousMemSet(location, value, size);
}

bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *device_dst,
                                       const void *host_src, uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(device_dst="
          << device_dst->opaque() << ", host_src=" << host_src
          << ", size=" << size << ") H2D" << StackTraceIfVLOG10();

  // Tracing overloaded methods is very difficult due to issues with type
  // inference on template args. Since use of these overloaded methods is
  // discouraged anyway, this isn't a huge deal.
  port::Status status =
      implementation_->SynchronousMemcpy(device_dst, host_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

bool StreamExecutor::SynchronousMemcpy(void *host_dst,
                                       const DeviceMemoryBase &device_src,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(host_dst=" << host_dst
          << ", device_src=" << device_src.opaque() << ", size=" << size
          << ") D2H" << StackTraceIfVLOG10();

  port::Status status =
      implementation_->SynchronousMemcpy(host_dst, device_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *device_dst,
                                       const DeviceMemoryBase &device_src,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(device_dst="
          << device_dst->opaque() << ", device_src=" << device_src.opaque()
          << ", size=" << size << ") D2D" << StackTraceIfVLOG10();

  port::Status status = implementation_->SynchronousMemcpyDeviceToDevice(
      device_dst, device_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

// Synchronously copies `size` bytes from `device_src` to `host_dst`, bracketed
// by SynchronousMemcpyD2H trace events. On failure the implementation's
// status is wrapped in an INTERNAL error that names the source, destination
// and size.
port::Status StreamExecutor::SynchronousMemcpyD2H(
    const DeviceMemoryBase &device_src, int64 size, void *host_dst) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpyD2H(device_src="
          << device_src.opaque() << ", size=" << size
          << ", host_dst=" << host_dst << ")" << StackTraceIfVLOG10();

  // Must be declared before the tracer: the tracer's destructor reads it.
  port::Status result;
  SCOPED_TRACE(TraceListener::SynchronousMemcpyD2H, &result, device_src, size,
               host_dst);

  result = implementation_->SynchronousMemcpy(host_dst, device_src, size);
  if (!result.ok()) {
    // Overwrite `result` (rather than returning a new Status) so listeners see
    // the same, enriched error that the caller receives.
    result = port::Status(
        port::error::INTERNAL,
        absl::StrFormat("failed to synchronously memcpy device-to-host: device "
                        "%p to host %p size %d: %s",
                        device_src.opaque(), host_dst, size,
                        result.ToString()));
  }

  return result;
}

// Synchronously copies `size` bytes from `host_src` to `device_dst`, bracketed
// by SynchronousMemcpyH2D trace events. On failure the implementation's
// status is wrapped in an INTERNAL error that names the source, destination
// and size.
port::Status StreamExecutor::SynchronousMemcpyH2D(
    const void *host_src, int64 size, DeviceMemoryBase *device_dst) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpyH2D(host_src=" << host_src
          << ", size=" << size << ", device_dst=" << device_dst->opaque() << ")"
          << StackTraceIfVLOG10();

  // Must be declared before the tracer: the tracer's destructor reads it.
  port::Status result;
  SCOPED_TRACE(TraceListener::SynchronousMemcpyH2D, &result, host_src, size,
               device_dst);

  result = implementation_->SynchronousMemcpy(device_dst, host_src, size);
  if (!result.ok()) {
    // Overwrite `result` so listeners observe the same enriched error that is
    // returned to the caller.
    result = port::Status(
        port::error::INTERNAL,
        absl::StrFormat("failed to synchronously memcpy host-to-device: host "
                        "%p to device %p size %d: %s",
                        host_src, device_dst->opaque(), size,
                        result.ToString()));
  }

  return result;
}

// Stream-ordered memory operations: enqueued on `stream` by the
// implementation.

bool StreamExecutor::Memcpy(Stream *stream, void *host_dst,
                            const DeviceMemoryBase &device_src, uint64 size) {
  return implementation_->Memcpy(stream, host_dst, device_src, size);
}

bool StreamExecutor::Memcpy(Stream *stream, DeviceMemoryBase *device_dst,
                            const void *host_src, uint64 size) {
  return implementation_->Memcpy(stream, device_dst, host_src, size);
}

bool StreamExecutor::MemcpyDeviceToDevice(Stream *stream,
                                          DeviceMemoryBase *device_dst,
                                          const DeviceMemoryBase &device_src,
                                          uint64 size) {
  return implementation_->MemcpyDeviceToDevice(stream, device_dst,
                                               device_src, size);
}

port::Status StreamExecutor::MemZero(Stream *stream, DeviceMemoryBase *location,
                                     uint64 size) {
  return implementation_->MemZero(stream, location, size);
}

// `size` is a byte count and must be a multiple of 4 (CHECK-enforced).
port::Status StreamExecutor::Memset32(Stream *stream,
                                      DeviceMemoryBase *location,
                                      uint32 pattern, uint64 size) {
  CHECK_EQ(0, size % 4)
      << "need 32-bit multiple size to fill with 32-bit pattern";
  return implementation_->Memset32(stream, location, pattern, size);
}

// Host callbacks and events: ownership of callbacks is moved into the
// implementation; event operations are pure pass-throughs.

bool StreamExecutor::HostCallback(Stream *stream,
                                  std::function<void()> callback) {
  return implementation_->HostCallback(stream, std::move(callback));
}

bool StreamExecutor::HostCallback(Stream *stream,
                                  std::function<port::Status()> callback) {
  return implementation_->HostCallback(stream, std::move(callback));
}

port::Status StreamExecutor::AllocateEvent(Event *event) {
  return implementation_->AllocateEvent(event);
}

port::Status StreamExecutor::DeallocateEvent(Event *event) {
  return implementation_->DeallocateEvent(event);
}

port::Status StreamExecutor::RecordEvent(Stream *stream, Event *event) {
  return implementation_->RecordEvent(stream, event);
}

port::Status StreamExecutor::WaitForEvent(Stream *stream, Event *event) {
  return implementation_->WaitForEvent(stream, event);
}

Event::Status StreamExecutor::PollForEventStatus(Event *event) {
  return implementation_->PollForEventStatus(event);
}

// Allocates `stream` in the implementation and increments the live-stream
// count. If the implementation fails, the increment is rolled back and false
// is returned.
bool StreamExecutor::AllocateStream(Stream *stream) {
  live_stream_count_.fetch_add(1, std::memory_order_relaxed);
  if (!implementation_->AllocateStream(stream)) {
    // fetch_sub returns the value *before* the decrement. Subtract one so the
    // check and the log message describe the count after the rollback.
    const auto count = live_stream_count_.fetch_sub(1) - 1;
    CHECK_GE(count, 0) << "live stream count should not dip below zero";
    LOG(INFO) << "failed to allocate stream; live stream count: " << count;
    return false;
  }

  return true;
}

// Deallocates `stream` in the implementation and decrements the live-stream
// count, CHECK-failing if the count would become negative.
void StreamExecutor::DeallocateStream(Stream *stream) {
  implementation_->DeallocateStream(stream);
  // Check the post-decrement value. Testing fetch_sub's return value (the old
  // count) against 0 would let the count silently reach -1.
  const auto remaining = live_stream_count_.fetch_sub(1) - 1;
  CHECK_GE(remaining, 0) << "live stream count should not dip below zero";
}

// Stream dependencies and timers: pass-throughs to implementation_.

bool StreamExecutor::CreateStreamDependency(Stream *dependent, Stream *other) {
  return implementation_->CreateStreamDependency(dependent, other);
}

bool StreamExecutor::AllocateTimer(Timer *timer) {
  return implementation_->AllocateTimer(timer);
}

void StreamExecutor::DeallocateTimer(Timer *timer) {
  implementation_->DeallocateTimer(timer);
}

bool StreamExecutor::StartTimer(Stream *stream, Timer *timer) {
  return implementation_->StartTimer(stream, timer);
}

bool StreamExecutor::StopTimer(Stream *stream, Timer *timer) {
  return implementation_->StopTimer(stream, timer);
}

// Builds a fresh DeviceDescription from the implementation. Dies (via
// ConsumeValueOrDie) if the implementation returns an error.
std::unique_ptr<DeviceDescription> StreamExecutor::CreateDeviceDescription()
    const {
  return implementation_->CreateDeviceDescription().ConsumeValueOrDie();
}

// Asks the implementation to fill `free` and `total`; returns whether it
// could report memory usage.
bool StreamExecutor::DeviceMemoryUsage(int64 *free, int64 *total) const {
  auto &impl = *implementation_;
  return impl.DeviceMemoryUsage(free, total);
}

// Hands `task` (by move) to the background thread pool for asynchronous
// execution.
void StreamExecutor::EnqueueOnBackgroundThread(std::function<void()> task) {
  auto &pool = *background_threads_;
  pool.Schedule(std::move(task));
}

// Records a device allocation for leak checking. Nothing is recorded unless
// FLAGS_check_device_leaks is set, and null or zero-sized allocations are
// always skipped.
void StreamExecutor::CreateAllocRecord(void *opaque, uint64 bytes) {
  const bool should_record =
      FLAGS_check_device_leaks && opaque != nullptr && bytes != 0;
  if (!should_record) {
    return;
  }
  absl::MutexLock lock(&mu_);
  mem_allocs_[opaque] = AllocRecord{bytes, ""};
  mem_alloc_bytes_ += bytes;
}

// Removes the leak-check record for `opaque` and subtracts its size from
// mem_alloc_bytes_. Only active when FLAGS_check_device_leaks is set; null
// pointers are ignored. Logs an error if `opaque` was never recorded.
void StreamExecutor::EraseAllocRecord(void *opaque) {
  if (FLAGS_check_device_leaks && opaque != nullptr) {
    absl::MutexLock lock(&mu_);
    // A single lookup: reuse the iterator both to read the recorded size and
    // to erase, instead of find() + operator[] + erase(key).
    auto it = mem_allocs_.find(opaque);
    if (it == mem_allocs_.end()) {
      LOG(ERROR) << "Deallocating unknown pointer: " << opaque;
    } else {
      mem_alloc_bytes_ -= it->second.bytes;
      mem_allocs_.erase(it);
    }
  }
}

// Turns trace submission on or off; SubmitTrace consults this flag before
// notifying any listeners.
void StreamExecutor::EnableTracing(bool enabled) {
  tracing_enabled_ = enabled;
}

// Adds `listener` to the set of trace listeners (logging if it was already
// present), then forwards the registration to the implementation.
void StreamExecutor::RegisterTraceListener(TraceListener *listener) {
  {
    absl::MutexLock lock(&mu_);
    const bool newly_added = listeners_.insert(listener).second;
    if (!newly_added) {
      LOG(INFO) << "Attempt to register already-registered listener, "
                << listener;
    }
  }

  implementation_->RegisterTraceListener(listener);
}

// Removes `listener` from the trace listener set. Returns false (and logs)
// if it was not registered; otherwise forwards the removal to the
// implementation and returns true.
bool StreamExecutor::UnregisterTraceListener(TraceListener *listener) {
  bool was_registered = false;
  {
    absl::MutexLock lock(&mu_);
    was_registered = listeners_.erase(listener) > 0;
    if (!was_registered) {
      LOG(INFO) << "Attempt to unregister unknown listener, " << listener;
    }
  }
  if (!was_registered) {
    return false;
  }

  implementation_->UnregisterTraceListener(listener);
  return true;
}

// Returns whatever allocator statistics the implementation exposes (empty
// optional when it has none to report).
absl::optional<AllocatorStats> StreamExecutor::GetAllocatorStats() {
  auto &impl = *implementation_;
  return impl.GetAllocatorStats();
}

// Invokes the TraceListener member `trace_call` with `args` on every
// registered listener, but only while tracing is enabled. The listener set
// is read under a shared (reader) lock on mu_.
template <typename TraceCallT, typename... ArgsT>
void StreamExecutor::SubmitTrace(TraceCallT trace_call, ArgsT &&... args) {
  if (!tracing_enabled_) {
    return;
  }
  absl::ReaderMutexLock lock(&mu_);
  for (TraceListener *listener : listeners_) {
    // Deliberately pass `args` as lvalues rather than std::forward-ing them:
    // forwarding inside a loop could move from an rvalue argument for the
    // first listener and hand every later listener a moved-from object.
    (listener->*trace_call)(args...);
  }
}

// Returns the executor the implementation reports via
// GetUnderlyingExecutor(), not implementation_ itself.
internal::StreamExecutorInterface *StreamExecutor::implementation() {
  auto &impl = *implementation_;
  return impl.GetUnderlyingExecutor();
}

StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator(
    StreamExecutor *executor)
    : DeviceMemoryAllocator(executor->platform()) {
  stream_executors_ = {executor};
}

// Builds an allocator for `platform` that serves every executor in
// `stream_executors`; the span's contents are copied into the allocator.
StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator(
    const Platform *platform,
    absl::Span<StreamExecutor *const> stream_executors)
    : DeviceMemoryAllocator(platform),
      stream_executors_(stream_executors.cbegin(), stream_executors.cend()) {}

// Allocates `size` bytes in `memory_space` on the executor for
// `device_ordinal` and wraps the result in an OwningDeviceMemory tied to
// this allocator.
//
// Returns InvalidArgument/NotFound (from GetStreamExecutor) for a bad
// ordinal, or ResourceExhausted when a non-empty request yields a null
// pointer. A zero-byte request is not treated as a failure.
// `retry_on_failure` is not consulted here.
port::StatusOr<OwningDeviceMemory> StreamExecutorMemoryAllocator::Allocate(
    int device_ordinal, uint64 size, bool retry_on_failure,
    int64 memory_space) {
  TF_ASSIGN_OR_RETURN(StreamExecutor * se, GetStreamExecutor(device_ordinal));
  DeviceMemoryBase mem = se->AllocateArray<uint8>(size, memory_space);

  const bool allocation_failed = size > 0 && mem == nullptr;
  if (allocation_failed) {
    return tensorflow::errors::ResourceExhausted(absl::StrFormat(
        "Failed to allocate request for %s (%uB) on device ordinal %d",
        tensorflow::strings::HumanReadableNumBytes(size), size,
        device_ordinal));
  }

  VLOG(3) << absl::StreamFormat(
      "Allocated %s (%uB) on device ordinal %d: %p",
      tensorflow::strings::HumanReadableNumBytes(size), size, device_ordinal,
      mem.opaque());
  return OwningDeviceMemory(mem, device_ordinal, this);
}

// Frees `mem` on the executor for `device_ordinal`. A null handle is a no-op
// that returns OK without validating the ordinal; otherwise a bad ordinal
// surfaces the error from GetStreamExecutor.
port::Status StreamExecutorMemoryAllocator::Deallocate(int device_ordinal,
                                                       DeviceMemoryBase mem) {
  if (mem.is_null()) {
    return port::Status::OK();
  }

  TF_ASSIGN_OR_RETURN(StreamExecutor * se, GetStreamExecutor(device_ordinal));
  VLOG(3) << absl::StreamFormat("Freeing %p on device ordinal %d",
                                mem.opaque(), device_ordinal);
  se->Deallocate(&mem);
  return port::Status::OK();
}

// Looks up the executor whose device_ordinal() equals `device_ordinal`.
// Returns InvalidArgument for a negative ordinal and NotFound when no
// executor held by this allocator matches.
port::StatusOr<StreamExecutor *>
StreamExecutorMemoryAllocator::GetStreamExecutor(int device_ordinal) const {
  if (device_ordinal >= 0) {
    for (size_t i = 0; i < stream_executors_.size(); ++i) {
      StreamExecutor *candidate = stream_executors_[i];
      if (candidate->device_ordinal() == device_ordinal) {
        return candidate;
      }
    }
    return tensorflow::errors::NotFound(
        absl::StrFormat("Device %s:%d present but not supported",
                        platform()->Name(), device_ordinal));
  }
  return tensorflow::errors::InvalidArgument(absl::StrFormat(
      "device ordinal value (%d) must be non-negative", device_ordinal));
}

// This allocator only deallocates synchronously; GetStream CHECKs this
// before handing out its cached per-device streams.
bool StreamExecutorMemoryAllocator::AllowsAsynchronousDeallocation() const {
  const bool async_deallocation_supported = false;
  return async_deallocation_supported;
}

// Returns the stream cached for `device_ordinal`. The first time an ordinal
// is requested, a Stream is constructed for its executor and Init() is called
// on it; later calls return the same object. Fails with the error from
// GetStreamExecutor for an invalid or unknown ordinal. The cache is guarded
// by mutex_.
port::StatusOr<Stream *> StreamExecutorMemoryAllocator::GetStream(
    int device_ordinal) {
  CHECK(!AllowsAsynchronousDeallocation())
      << "The logic below only works for synchronous allocators";
  TF_ASSIGN_OR_RETURN(StreamExecutor * executor,
                      GetStreamExecutor(device_ordinal));
  Stream *out = [&] {
    absl::MutexLock lock(&mutex_);
    // One lookup on the common (cached) path instead of count() + at().
    auto it = streams_.find(device_ordinal);
    if (it == streams_.end()) {
      // Construct the Stream in place: it is built from `executor` and is
      // not copied or moved into the map.
      it = streams_
               .emplace(std::piecewise_construct,
                        std::forward_as_tuple(device_ordinal),
                        std::forward_as_tuple(executor))
               .first;
      it->second.Init();
    }
    return &it->second;
  }();
  return out;
}

}  // namespace stream_executor
